{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp models.TransformerRNNPlus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TransformerRNNPlus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a PyTorch implementation of a Transformer + RNN created by Ignacio Oguiza - oguiza@timeseriesAI.co, inspired by the code created by Baurzhan Urazalinov (https://www.kaggle.com/baurzhanurazalinov).\n",
    "\n",
    "Baurzhan Urazalinov won a Kaggle competition (Parkinson's Freezing of Gait Prediction: Event detection from wearable sensor data - 2023) using the following original TensorFlow code:\n",
    "\n",
    "* https://www.kaggle.com/code/baurzhanurazalinov/parkinson-s-freezing-defog-training-code\n",
    "* https://www.kaggle.com/code/baurzhanurazalinov/parkinson-s-freezing-tdcsfog-training-code\n",
    "* https://www.kaggle.com/code/baurzhanurazalinov/parkinson-s-freezing-submission-code\n",
    "\n",
    "I'd like to congratulate Baurzhan for winning this competition, and for sharing the code he used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import TransformerEncoder, TransformerEncoderLayer\n",
    "from collections import OrderedDict\n",
    "from tsai.models.layers import lin_nd_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tsai.models.utils import count_parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([4, 864, 54])\n",
      "235382\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: one batch-first encoder layer on a (4, 864, 54) input; prints the output shape and parameter count.\n",
    "t = torch.rand(4, 864, 54)\n",
    "encoder_layer = torch.nn.TransformerEncoderLayer(\n",
    "    d_model=54, nhead=6, dim_feedforward=2048, dropout=0.1,\n",
    "    activation=\"relu\", layer_norm_eps=1e-05,\n",
    "    batch_first=True, norm_first=False,\n",
    ")\n",
    "print(encoder_layer(t).shape)\n",
    "print(count_parameters(encoder_layer))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class _TransformerRNNEncoder(nn.Module):\n",
    "    \"\"\"Backbone: linear projection + learned positional encoding -> transformer encoder layers -> stacked RNN layers.\n",
    "\n",
    "    Input:  (batch_size, c_in, seq_len)\n",
    "    Output: (batch_size, d_model * (1 + bidirectional) ** num_rnn_layers, seq_len)\n",
    "\n",
    "    Every sub-layer works batch-first on (batch_size, seq_len, features), so attention and recurrence\n",
    "    run along the time axis of each sample. Earlier versions ran both along the batch axis; parameter\n",
    "    shapes are unchanged, but weights trained with that behaviour should be retrained.\n",
    "    \"\"\"\n",
    "    def __init__(self, \n",
    "        cell:nn.Module, # RNN class to instantiate (e.g. nn.RNN, nn.LSTM, nn.GRU). Must accept `bidirectional` and `batch_first`.\n",
    "        c_in:int, # Number of channels in the input tensor.\n",
    "        seq_len:int, # Number of time steps in the input tensor.\n",
    "        d_model:int, # The number of expected features in the input.\n",
    "        nhead:int, # Number of parallel attention heads (d_model will be split across nhead - each head will have dimension d_model // nhead).\n",
    "        proj_dropout:float=0.1, # Dropout probability after the projection linear layer. Default: 0.1.\n",
    "        num_encoder_layers:int=1, # Number of transformer layers in the encoder. Default: 1.\n",
    "        dim_feedforward:int=2048, # The dimension of the feedforward network model. None or 0 falls back to d_model. Default: 2048.\n",
    "        dropout:float=0.1, # Transformer encoder layers dropout. Default: 0.1.\n",
    "        num_rnn_layers:int=1, # Number of RNN layers in the encoder. Default: 1.\n",
    "        bidirectional:bool=True, # If True, becomes a bidirectional RNN. Default: True.\n",
    "        ):\n",
    "        super().__init__()\n",
    "\n",
    "        # projection layer\n",
    "        self.proj_linear = nn.Linear(c_in, d_model)\n",
    "        self.proj_dropout = nn.Dropout(proj_dropout)\n",
    "\n",
    "        # transformer encoder layers (batch_first=True: they consume (batch_size, seq_len, d_model))\n",
    "        self.num_encoder_layers = num_encoder_layers\n",
    "        dim_feedforward = dim_feedforward or d_model\n",
    "        self.enc_layers = nn.ModuleList([\n",
    "            TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True)\n",
    "            for _ in range(num_encoder_layers)\n",
    "        ])\n",
    "\n",
    "        # rnn layers: layer i maps d_model * n_dir ** i features to n_dir times as many (n_dir = 2 if bidirectional else 1)\n",
    "        self.num_rnn_layers = num_rnn_layers\n",
    "        n_dir = 1 + bidirectional\n",
    "        rnn_sizes = [d_model * n_dir ** i for i in range(num_rnn_layers)]\n",
    "        # batch_first=True is required: torch RNNs default to (seq_len, batch_size, features) inputs\n",
    "        self.rnn_layers = nn.ModuleList([cell(size, size, bidirectional=bidirectional, batch_first=True) for size in rnn_sizes])\n",
    "\n",
    "        # learned positional encoding, one vector per time step\n",
    "        self.seq_len = seq_len\n",
    "        self.pos_encoding = nn.Parameter(torch.randn(1, self.seq_len, d_model) * 0.02)\n",
    "\n",
    "    def forward(self, x): # (batch_size, c_in, seq_len), Example shape (4, 54, 864) \n",
    "        \"\"\"Encode `x` of shape (batch_size, c_in, seq_len) into (batch_size, features, seq_len).\"\"\"\n",
    "        x = x.swapaxes(1, 2) # (batch_size, seq_len, c_in), Example shape (4, 864, 54)\n",
    "\n",
    "        # projection layer\n",
    "        x = self.proj_linear(x) # (batch_size, seq_len, d_model), Example shape (4, 864, 320)\n",
    "        x = x + self.pos_encoding # (1, seq_len, d_model) broadcasts over the batch dimension\n",
    "        x = self.proj_dropout(x)\n",
    "\n",
    "        # transformer encoder layers: already batch-first, so no transpose is needed\n",
    "        for enc_layer in self.enc_layers:\n",
    "            x = enc_layer(x) # (batch_size, seq_len, d_model), Example shape (4, 864, 320)\n",
    "\n",
    "        # rnn layers\n",
    "        for rnn_layer in self.rnn_layers:\n",
    "            x, _ = rnn_layer(x) # (batch_size, seq_len, in_features * (1 + bidirectional)), Example shape (4, 864, 640)\n",
    "\n",
    "        x = x.swapaxes(1, 2) # back to channels-first: (batch_size, features, seq_len)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([4, 1024, 50])\n"
     ]
    }
   ],
   "source": [
    "bs, c_in, seq_len = 4, 5, 50\n",
    "\n",
    "# 3 stacked bidirectional LSTM layers: the feature width doubles at each layer (128 -> 256 -> 512 -> 1024)\n",
    "encoder = _TransformerRNNEncoder(\n",
    "    nn.LSTM, c_in=c_in, seq_len=seq_len, d_model=128, nhead=4,\n",
    "    num_encoder_layers=1, dim_feedforward=None, proj_dropout=0.1, dropout=0.1,\n",
    "    num_rnn_layers=3, bidirectional=True,\n",
    ")\n",
    "t = torch.randn(bs, c_in, seq_len)\n",
    "print(encoder(t).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class _TransformerRNNPlus(nn.Sequential):\n",
    "    \"\"\"Base model: a `_TransformerRNNEncoder` backbone followed by a head.\n",
    "\n",
    "    Subclasses choose the recurrent layer type through the `_cell` class attribute.\n",
    "    \"\"\"\n",
    "    def __init__(self, \n",
    "        c_in:int, # Number of channels in the input tensor.\n",
    "        c_out:int, # Number of output channels.\n",
    "        seq_len:int, # Number of time steps in the input tensor.\n",
    "        d:tuple=None, # int or tuple with shape of the output tensor\n",
    "        d_model:int=128, # Total dimension of the model.\n",
    "        nhead:int=16, # Number of parallel attention heads (d_model will be split across nhead - each head will have dimension d_model // nhead).\n",
    "        proj_dropout:float=0.1, # Dropout probability after the first linear layer. Default: 0.1.\n",
    "        num_encoder_layers:int=1, # Number of transformer encoder layers. Default: 1.\n",
    "        dim_feedforward:int=2048, # The dimension of the feedforward network model. Default: 2048.\n",
    "        dropout:float=0.1, # Transformer encoder layers dropout. Default: 0.1.\n",
    "        num_rnn_layers:int=1, # Number of RNN layers in the encoder. Default: 1.\n",
    "        bidirectional:bool=True, # If True, becomes a bidirectional RNN. Default: True.\n",
    "        custom_head=None, # Custom head that will be applied to the model. If None, a head with `c_out` outputs will be used. Default: None.\n",
    "        **kwargs\n",
    "        ):\n",
    "\n",
    "        encoder_kwargs = dict(\n",
    "            cell=self._cell, c_in=c_in, seq_len=seq_len, proj_dropout=proj_dropout,\n",
    "            d_model=d_model, nhead=nhead, num_encoder_layers=num_encoder_layers,\n",
    "            dim_feedforward=dim_feedforward, dropout=dropout,\n",
    "            num_rnn_layers=num_rnn_layers, bidirectional=bidirectional,\n",
    "        )\n",
    "        backbone = _TransformerRNNEncoder(**encoder_kwargs)\n",
    "\n",
    "        # number of features fed to the head: d_model, doubled once per RNN layer when bidirectional\n",
    "        self.head_nf = d_model * ((1 + bidirectional) ** (num_rnn_layers))\n",
    "\n",
    "        if not custom_head:\n",
    "            # default head\n",
    "            head = lin_nd_head(self.head_nf, c_out, seq_len=seq_len, d=d)\n",
    "        elif isinstance(custom_head, nn.Module):\n",
    "            # an already-built module is used as is\n",
    "            head = custom_head\n",
    "        else:\n",
    "            # otherwise custom_head is called as a factory; extra **kwargs are forwarded to it\n",
    "            head = custom_head(self.head_nf, c_out, seq_len, d=d, **kwargs)\n",
    "\n",
    "        super().__init__(OrderedDict(backbone=backbone, head=head))\n",
    "\n",
    "class TransformerRNNPlus(_TransformerRNNPlus):\n",
    "    \"\"\"`_TransformerRNNPlus` variant whose recurrent layers are `nn.RNN`.\"\"\"\n",
    "    _cell = nn.RNN\n",
    "\n",
    "class TransformerLSTMPlus(_TransformerRNNPlus):\n",
    "    \"\"\"`_TransformerRNNPlus` variant whose recurrent layers are `nn.LSTM`.\"\"\"\n",
    "    _cell = nn.LSTM\n",
    "\n",
    "class TransformerGRUPlus(_TransformerRNNPlus):\n",
    "    \"\"\"`_TransformerRNNPlus` variant whose recurrent layers are `nn.GRU`.\"\"\"\n",
    "    _cell = nn.GRU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([4])\n",
      "torch.Size([4])\n",
      "torch.Size([4])\n"
     ]
    }
   ],
   "source": [
    "bs = 4\n",
    "c_in = 5\n",
    "c_out = 1\n",
    "seq_len = 50\n",
    "d = None\n",
    "\n",
    "# With c_out == 1 and d=None every architecture must return one value per sample: shape (bs,)\n",
    "for arch in (TransformerRNNPlus, TransformerLSTMPlus, TransformerGRUPlus):\n",
    "    model = arch(c_in=c_in, c_out=c_out, seq_len=seq_len, d=d, proj_dropout=0.1, d_model=128, nhead=4, num_encoder_layers=2, dropout=0.1, num_rnn_layers=1, bidirectional=True)\n",
    "    t = torch.randn(bs, c_in, seq_len)\n",
    "    out = model(t) # single forward pass, reused by the check and the print\n",
    "    assert out.shape == (bs,)\n",
    "    print(out.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([4, 3])\n",
      "torch.Size([4, 3])\n",
      "torch.Size([4, 3])\n"
     ]
    }
   ],
   "source": [
    "bs = 4\n",
    "c_in = 5\n",
    "c_out = 3\n",
    "seq_len = 50\n",
    "d = None\n",
    "\n",
    "# With c_out == 3 and d=None every architecture must return shape (bs, c_out)\n",
    "for arch in (TransformerRNNPlus, TransformerLSTMPlus, TransformerGRUPlus):\n",
    "    model = arch(c_in=c_in, c_out=c_out, seq_len=seq_len, d=d, proj_dropout=0.1, d_model=128, nhead=4, num_encoder_layers=2, dropout=0.1, num_rnn_layers=1, bidirectional=True)\n",
    "    t = torch.randn(bs, c_in, seq_len)\n",
    "    out = model(t) # single forward pass, reused by the check and the print\n",
    "    assert out.shape == (bs, c_out)\n",
    "    print(out.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([4, 50, 3])\n",
      "torch.Size([4, 50, 3])\n",
      "torch.Size([4, 50, 3])\n"
     ]
    }
   ],
   "source": [
    "bs = 4\n",
    "c_in = 5\n",
    "c_out = 3\n",
    "seq_len = 50\n",
    "d = 50\n",
    "\n",
    "# With an integer d every architecture must return shape (bs, d, c_out)\n",
    "for arch in (TransformerRNNPlus, TransformerLSTMPlus, TransformerGRUPlus):\n",
    "    model = arch(c_in=c_in, c_out=c_out, seq_len=seq_len, d=d, proj_dropout=0.1, d_model=128, nhead=4, num_encoder_layers=2, dropout=0.1, num_rnn_layers=1, bidirectional=True)\n",
    "    t = torch.randn(bs, c_in, seq_len)\n",
    "    out = model(t) # single forward pass, reused by the check and the print\n",
    "    assert out.shape == (bs, d, c_out)\n",
    "    print(out.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": "IPython.notebook.save_checkpoint();",
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/nacho/notebooks/tsai/nbs/078_models.TransformerRNNPlus.ipynb couldn't be saved automatically. You should save it manually 👋\n",
      "Correct notebook to script conversion! 😃\n",
      "Saturday 17/06/23 12:20:57 CEST\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcO
wn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "# Look up this notebook's name, then run the tsai notebook-to-script conversion for it.\n",
    "from tsai.export import get_nb_name; nb_name = get_nb_name(locals())\n",
    "from tsai.imports import create_scripts; create_scripts(nb_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
